[JAX] Support for cuDNN-backed flex attention#2985
Conversation
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR adds experimental cuDNN-frontend-backed flex attention (
Confidence Score: 5/5Safe to merge as an experimental feature; all flagged items are non-blocking quality improvements with no correctness impact. The core forward and backward graph building, caching, FFI dispatch, and Flax plumbing are all structurally correct. Cache key stability, UID ordering, and pytree gradient structure are handled properly. The findings are race conditions that produce at worst redundant work (not wrong results) and a shutdown-order concern for the thread-local cuDNN handle that matches patterns already present elsewhere in the codebase. transformer_engine/jax/csrc/extensions/attention.cpp (double-checked locking in GetScoreModGraph, thread-local handle destructor ordering) and transformer_engine/jax/cpp_extensions/flex_attention.py (Python-level cache lock). Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant fused_attn
participant ScoreMod as "_fused_attn_score_mod"
participant FlexPy as "flex_attention.py"
participant FFI as "FFI/XLA"
participant Cpp as "C++ Handler"
participant Cache as "cuDNN Graph Cache"
User->>fused_attn: "call with score_mod callback"
fused_attn->>fused_attn: "validate_fused_attn_score_mod()"
fused_attn->>FlexPy: "make_fused_attn_score_mod_config()"
fused_attn->>ScoreMod: "custom_vjp forward"
Note over ScoreMod,FlexPy: JAX Tracing Phase
ScoreMod->>FlexPy: "fused_attn_score_mod_fwd()"
FlexPy->>FlexPy: "check _score_mod_graph_cache"
alt cache miss
FlexPy->>FlexPy: "_build_score_mod_fwd_graph()"
FlexPy->>FlexPy: "store in _score_mod_graph_cache"
end
FlexPy->>FFI: "ffi.ffi_call(serialized_graph, uids)"
Note over FFI,Cache: XLA Execution Phase
FFI->>Cpp: "FusedAttnScoreModForwardFFI(stream, q, k, v)"
Cpp->>Cache: "GetScoreModGraph(stream, attrs)"
alt C++ cache miss
Cache->>Cache: "graph->deserialize(handle, data)"
Cache->>Cache: "store shared_ptr in map"
end
Cpp->>Cpp: "graph->execute(handle, variant_pack)"
Cpp-->>FFI: "output, stats, workspace"
Note over ScoreMod,FlexPy: Backward pass
ScoreMod->>FlexPy: "fused_attn_score_mod_bwd(qkv, o, dO, stats)"
FlexPy->>FFI: "ffi.ffi_call(serialized_bwd_graph)"
FFI->>Cpp: "FusedAttnScoreModBackwardFFI(...)"
Cpp-->>FFI: "dq, dk, dv"
Reviews (9): Last reviewed commit: "Skip softcap score-mod test before SM90" | Re-trigger Greptile |
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Description
Adds experimental JAX fused-attention
score_modsupport through cuDNN frontend SDPA graphs.This introduces a
score_mod(graph, score, tensors)callback path forfused_attn, plus optionalscore_mod_bprop(graph, dscore, tensors)support for backward. The Python side builds and serializes cuDNN frontend forward/backward graphs, caches graph metadata with stable callback keys, supports auxiliary tensor operands, and supports Python/NumPy scalar operands as cuDNN pass-by-value tensors. The C++ JAX extension deserializes and caches the graphs per device, then executes them through new forward/backward FFI handlers.The Flax API now plumbs
score_modthroughDotProductAttention,MultiHeadAttention, andTransformerLayer. Packed QKV/KV layouts are unpacked to the separate BSHD layout when score modification is requested.Users are responsible for supplying a mathematically correct
score_mod_bpropfor the correspondingscore_mod; Transformer Engine wires the callback into the cuDNN graph but does not validate gradient semantics.Current score_mod limitations:
BSHD_BSHD_BSHDQ/K/V tensors only.Fixes # (issue)
#2492
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: